package regression.bank.ml

import java.io.{FileInputStream, FileOutputStream}
import java.util.Date
import javax.xml.transform.stream.{StreamResult, StreamSource}

import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.sql.SparkSession
import org.dmg.pmml.FieldName
import org.jpmml.evaluator.ModelEvaluatorFactory
import org.jpmml.model.JAXBUtil
import org.jpmml.sparkml.ConverterUtil

import scala.collection.JavaConversions._

/**
  * Input data format: the "rej" field is the label (bank loan application rejection rate?).
  * a1cx,a1cy,a1sx,a1sy,a1rho,a1pop,a2cx,a2cy,a2sx,a2sy,a2rho,a2pop,a3cx,a3cy,a3sx,a3sy,a3rho,a3pop,temp,b1x,b1y,b1call,b1eff,b2x,b2y,b2call,b2eff,b3x,b3y,b3call,b3eff,mxql,rej,
  *0.413010,0.607442,0.332608,0.406812,-0.151224,1.525222,-0.144368,0.852368,0.412397,1.728169,-0.449231,4.078482,0.232042,-0.323190,0.792235,0.421474,-0.307503,3.086689,0.363949,0.441308,-0.276851,2.000000,1.974706,-0.776759,-0.783770,8.000000,0.603486,-0.997118,-0.502138,5.000000,1.169388,9.000000,0.049118,
  * -0.602384,0.350618,0.429196,0.414476,-0.124489,4.597991,0.579458,0.651134,0.104394,0.636356,-0.283787,3.546643,0.115860,0.409074,2.152997,0.758680,0.341127,1.478951,0.662488,0.462398,0.339673,6.000000,0.798979,-0.002820,-0.080542,2.000000,1.125542,-0.983397,-0.107632,5.000000,1.186039,7.000000,0.242579,
  * -0.322881,-0.538491,1.602260,0.039605,0.196023,1.909005,-0.675672,0.963618,0.147458,1.414008,0.495453,0.056459,-0.163151,0.350221,1.124090,1.398160,-0.456921,1.600723,0.650252,-0.247380,0.318002,3.000000,0.577355,-0.952645,-0.571600,5.000000,1.280392,0.771129,-0.665756,5.000000,1.024203,6.000000,0.000000,
  * -0.233570,-0.936451,1.710192,2.179527,0.438461,4.742055,-0.163625,-0.923273,0.597622,0.118409,0.229981,3.209085,-0.165046,0.012872,0.398148,1.335824,0.119910,13.070052,0.308221,-0.743841,0.258362,4.000000,0.760084,-0.198235,-0.205276,2.000000,0.509727,-0.579544,0.480094,6.000000,1.568492,7.000000,0.469045,
  *0.403126,0.313367,0.822382,1.393975,0.253435,9.398630,0.312528,0.288321,0.431867,0.110369,0.294665,1.274100,0.328350,-0.288962,0.067075,0.632938,0.148618,3.633846,0.233204,-0.685285,-0.758206,6.000000,1.170067,0.573352,0.315217,2.000000,0.622033,-0.134747,0.669948,3.000000,1.295913,9.000000,0.000000,
  * -0.017878,0.220771,0.037941,0.298548,-0.472770,3.762778,-0.795460,0.165907,0.421014,0.013787,-0.139091,0.381122,0.836916,-0.801391,2.242973,0.194025,-0.303852,0.636225,0.554068,0.602932,-0.772229,3.000000,1.063314,-0.387546,-0.253900,5.000000,0.597801,-0.631752,0.041766,3.000000,1.466912,7.000000,0.000000,
  *0.345700,0.658409,0.004313,0.041727,0.248306,1.476620,0.317824,-0.015015,0.870175,1.059934,0.275087,0.654086,-0.330889,0.908071,1.308776,0.143776,0.208710,5.227891,0.409048,0.649310,0.010654,7.000000,0.901486,0.368354,-0.136157,4.000000,0.869763,0.680687,-0.287984,8.000000,0.653114,9.000000,0.000000,
  *0.191042,0.954493,0.841096,1.255521,0.464016,1.099298,-0.314780,-0.110777,1.102842,0.046679,-0.466958,1.642563,-0.350668,0.958245,0.222133,2.157210,0.463938,8.345275,0.583435,-0.011537,0.864393,5.000000,0.824826,-0.385423,0.085826,7.000000,1.015563,-0.241577,-0.167008,7.000000,0.861025,8.000000,0.230277,
  *0.239531,-0.391718,0.317551,2.956340,0.122455,6.116613,-0.992386,-0.734185,0.797696,1.341425,0.430840,0.571166,0.643303,0.049932,0.964677,0.027338,0.427852,1.623983,0.642157,0.668291,-0.447812,4.000000,1.318010,0.323396,-0.949254,2.000000,0.971576,0.028994,0.384228,8.000000,1.500834,9.000000,0.041226,
  */

/**
  * Created by peibin on 2017/4/13.
  */
object BankLRDemo {

  /**
    * Trains a linear-regression model on the bank32nh dataset with cross-validated
    * hyper-parameter search, reports test RMSE, persists the best model as PMML,
    * then reloads the PMML and re-scores one training row to verify the round trip.
    *
    * @param args optional overrides: args(0) = input CSV path, args(1) = PMML output
    *             path. Defaults preserve the original hard-coded locations.
    */
  def main(args: Array[String]): Unit = {
    // Hard-coded defaults kept for backward compatibility; override via CLI args.
    val dataPath =
      if (args.length > 0) args(0)
      else """/Users/peibin/workspace/data-mining/regression-datasets/bank32nh.data"""
    val persistPath = if (args.length > 1) args(1) else "/tmp/lr"

    val spark = SparkSession
      .builder()
      .appName("BankLR")
      .master("local[*]")
      .getOrCreate()

    try {
      // Every input column except the label "rej" is a feature.
      val featureFields = Array("a1cx", "a1cy", "a1sx", "a1sy", "a1rho", "a1pop", "a2cx", "a2cy", "a2sx", "a2sy", "a2rho", "a2pop", "a3cx", "a3cy", "a3sx", "a3sy", "a3rho", "a3pop", "temp", "b1x", "b1y", "b1call", "b1eff", "b2x", "b2y", "b2call", "b2eff", "b3x", "b3y", "b3call", "b3eff", "mxql")
      // CSV columns arrive as strings; cast every feature plus the label to double.
      val castExprs = featureFields.map(x => s"cast($x as double) $x") :+ "cast(rej as double) rej"

      val data = spark.read.option("header", "true").csv(dataPath).selectExpr(castExprs: _*)
      // data.printSchema()
      val Array(training, test) = data.randomSplit(Array(0.7, 0.3))

      // Assemble the raw numeric columns into a single "features" vector column.
      val vectorAssembler = new VectorAssembler()
        .setInputCols(featureFields)
        .setOutputCol("features")

      val lir = new LinearRegression()
        .setFeaturesCol("features")
        .setLabelCol("rej")

      val pipeline = new Pipeline().setStages(Array(vectorAssembler, lir))

      // Cross-validate over a small parameter grid and keep the best model.
      val paramMaps = new ParamGridBuilder()
        .addGrid(lir.maxIter, Array(1, 5, 10))
        .addGrid(lir.tol, Array(1e-6, 1e-5))
        .build()
      val eval = new RegressionEvaluator().setLabelCol("rej").setPredictionCol("prediction")
      val cv = new CrossValidator()
        .setEstimator(pipeline)
        .setEstimatorParamMaps(paramMaps)
        .setEvaluator(eval)
      val model = cv.fit(training).bestModel

      model.transform(training).select("prediction", "rej").show(10)

      // Score the held-out set and report RMSE via mllib's RegressionMetrics.
      val scored = model.transform(test)
      val predictions = scored.select("prediction").rdd.map(_.getDouble(0))
      val labels = scored.select("rej").rdd.map(_.getDouble(0))
      val rmse = new RegressionMetrics(predictions.zip(labels)).rootMeanSquaredError
      println(s"Root mean squared error (RMSE): $rmse")

      // Build a single-row dataset used to verify the PMML persist/reload round trip.
      val singleTestRow = training.collect().head
      val singleTestDF = spark.createDataFrame(Seq(singleTestRow), training.schema)
      val singleTestMap = featureFields.map(x => new FieldName(x) -> singleTestRow.getAs[Double](x)).toMap
      // Prediction from the in-memory Spark model, for comparison with the PMML one below.
      model.transform(singleTestDF).select("prediction", "rej").show()

      // Convert the fitted pipeline to PMML and persist it to disk. This must run
      // BEFORE spark.stop(): the original called stop() first, risking failures on
      // any later access to Spark-backed state (schema/model). Streams are closed
      // explicitly — the original leaked both file handles.
      val pmml = ConverterUtil.toPMML(training.schema, model.asInstanceOf[PipelineModel])
      val out = new FileOutputStream(persistPath)
      try JAXBUtil.marshalPMML(pmml, new StreamResult(out))
      finally out.close()
      // JAXBUtil.marshalPMML(pmml, new StreamResult(System.out));

      // Reload the PMML from disk and score the same row with the JPMML evaluator.
      val in = new FileInputStream(persistPath)
      val reloadedPmml =
        try JAXBUtil.unmarshalPMML(new StreamSource(in))
        finally in.close()
      val reloadedModel = ModelEvaluatorFactory.newInstance.newModelEvaluator(reloadedPmml)
      // singleTestMap is adapted to a java.util.Map by the JavaConversions import.
      println(new Date + " : " + reloadedModel.evaluate(singleTestMap))
    } finally {
      // Always release the SparkSession, even if training or PMML export fails.
      spark.stop()
    }
  }
}
